# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions to evaluate expression strings.

This module includes the functions to evaluate the performance of the expression
strings generated by generative models. Please avoid using Python's built-in
eval() function for safety reasons. All the expression string evaluation
functions must use the numpy_array_eval() function defined in this module to
evaluate the expression with numpy array arguments.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import ast

import numpy as np
import six


def divide_with_zero_divisor(dividend, divisor):
  """Divides element-wise, producing 0 wherever the divisor is zero.

  Positions with a zero divisor are never computed: they keep the value of a
  zero-filled output buffer instead of becoming inf/nan.

  Args:
    dividend: Numpy array or scalar.
    divisor: Numpy array or scalar.

  Returns:
    A float64 numpy array with the broadcast shape of (dividend, divisor);
    it is 0-d when both inputs are scalars.
  """
  # The output buffer must have the broadcast shape of BOTH operands rather
  # than the dividend's shape alone: a scalar dividend over an array divisor
  # containing zeros must still yield an array. It has to be pre-filled with
  # zeros because np.true_divide leaves the masked-out (~where) positions
  # untouched; omitting `out` would leave them uninitialized.
  ones = np.ones(np.broadcast(dividend, divisor).shape)
  nonzero_divisor = (divisor * ones) != 0
  zeros = np.zeros_like(ones)
  return np.true_divide(dividend, divisor, out=zeros, where=nonzero_divisor)


def power_with_zero_base(base, exponent):
  """Computes base ** exponent, returning 0 where base is 0 and exponent < 0.

  The default np.power() emits a RuntimeWarning and returns inf when a zero
  base is raised to a negative power. It also raises ValueError for integer
  inputs with a negative integer exponent (e.g. 2 ** -1), because it selects
  an integer loop. Here the zero-base / negative-exponent positions are masked
  out and left at 0. The remaining positions are computed with np.float_power,
  which always works in at least float64 precision, so integer inputs with
  negative exponents are supported.

  Args:
    base: Numpy array or scalar.
    exponent: Numpy array or scalar.

  Returns:
    A float64 numpy array with the broadcast shape of (base, exponent); it is
    0-d when both inputs are scalars.
  """
  # Output buffer with the broadcast shape of both operands, pre-filled with
  # zeros so that masked-out positions read as 0 rather than uninitialized.
  broadcast = np.ones(np.broadcast(base, exponent).shape)
  # Compute everywhere except where a zero base meets a negative exponent.
  computable = np.logical_or(base * broadcast != 0, exponent * broadcast >= 0)
  return np.float_power(base,
                        exponent,
                        out=np.zeros_like(broadcast),
                        where=computable)


# Maps ast operator node types to the functions numpy_array_eval() uses to
# evaluate them. Division and exponentiation go through the zero-safe wrappers
# divide_with_zero_divisor / power_with_zero_base instead of raw numpy ufuncs.
_OPERATORS = {
    # Unary operators. Unary plus ("+x") is supported alongside negation.
    ast.UAdd: np.positive,
    ast.USub: np.negative,
    # Binary operators.
    ast.Add: np.add,
    ast.Sub: np.subtract,
    ast.Mult: np.multiply,
    ast.Div: divide_with_zero_divisor,
    ast.Pow: power_with_zero_base,
}

# Default whitelist of function names an expression string may call, mapped to
# the numpy function that implements each one. Note that 'divide' is the plain
# np.divide: unlike the `/` operator, it does not special-case zero divisors.
_CALLABLES = dict(
    sin=np.sin,
    cos=np.cos,
    sqrt=np.sqrt,
    exp=np.exp,
    log=np.log,
    log10=np.log10,
    abs=np.abs,
    add=np.add,
    subtract=np.subtract,
    multiply=np.multiply,
    divide=np.divide,
)


def _numeric_literal(node):
  """Returns the number held by a numeric literal node, or None otherwise.

  Python 3.8+ parses numbers into ast.Constant nodes (ast.Num is deprecated
  since 3.8 and removed in 3.14); earlier versions produce ast.Num nodes whose
  value is stored in `.n`. Booleans and non-numeric constants (strings, None,
  ...) are not treated as numbers, matching the old ast.Num semantics.

  Args:
    node: ast.AST node.

  Returns:
    An int, float or complex if node is a numeric literal, else None.
  """
  if isinstance(node, getattr(ast, 'Constant', ())):
    value = node.value
    if isinstance(value, (int, float, complex)) and not isinstance(value, bool):
      return value
    return None
  # Compare by class name so the deprecated ast.Num attribute is never touched.
  if type(node).__name__ == 'Num':
    return node.n
  return None


def numpy_array_eval(string, callables=None, arguments=None):
  """Evaluates string with numpy array and whitelisted function calls.

  Python's built-in eval() function has safety issues.
  ast library has literal_eval() function safely evaluate an expression which
  only consists of the following Python literal structures: strings, numbers,
  tuples, lists, dicts, booleans, and None.
  See https://docs.python.org/3/library/ast.html#ast.literal_eval

  This function evaluates expression with numbers, numpy arrays,
  constant values, the operators in _OPERATORS, and whitelisted function calls
  with positional arguments only.

  Args:
    string: A string of expression.
    callables: An optional dictionary mapping a function name in expression
        string to a callable function. For example,
        {'sin': np.sin, 'sum': np.sum}.
        These entries are added to (and override) the default _CALLABLES for
        this call only; _CALLABLES itself is never modified.
    arguments: An optional dictionary mapping an argument name in expression
        string to a number or numpy array. For example,
        {'a': 0.125, 'n': np.arange(5)}
        If not provided, defaults to {}.

  Returns:
    The evaluation of string.

  Raises:
    ValueError: Occurs if input callables or arguments is not a dictionary.
    SyntaxError: Occurs if the string cannot be parsed, or uses an unknown
        argument, callable or operator, or any other unsupported construct
        (attribute calls, keyword arguments, comparisons, ...).
  """
  # Work on a copy so user-supplied callables do not leak into the shared
  # module-level whitelist (and thus into every later call).
  available_callables = dict(_CALLABLES)
  if callables is not None:
    if not isinstance(callables, dict):
      raise ValueError('Input callables expected to be a dict.')
    available_callables.update(callables)

  if arguments is None:
    arguments = {}

  if not isinstance(arguments, dict):
    raise ValueError('Input arguments expected to be a dict.')

  node = ast.parse(string, mode='eval')
  if isinstance(node, ast.Expression):
    node = node.body

  def _lookup_operator(op):
    """Returns the function for an ast operator, or raises SyntaxError."""
    operator = _OPERATORS.get(type(op))
    if operator is None:
      raise SyntaxError('Unknown operator: %r' % type(op).__name__)
    return operator

  def _eval(node):
    """Evaluates the node.

    Args:
      node: ast.AST node class.

    Returns:
      The evaluation of the parsing tree, the root of which is the input node.

    Raises:
      SyntaxError: Occurs if a argument in string is not in arguments, or a
          callable function in string is not in callables, or an operator is
          not in _OPERATORS, or the input string is malformed.
    """
    number = _numeric_literal(node)
    if number is not None:
      return number
    if isinstance(node, ast.UnaryOp):
      return _lookup_operator(node.op)(_eval(node.operand))
    if isinstance(node, ast.BinOp):
      return _lookup_operator(node.op)(_eval(node.left), _eval(node.right))
    if isinstance(node, ast.Name):
      if node.id not in arguments:
        raise SyntaxError('Unknown argument: %r' % node.id)
      return arguments[node.id]
    if isinstance(node, ast.Call):
      # Only plain names such as `sin(x)` may be called; attribute calls like
      # `np.sin(x)` or `x.sum()` have no `.id` and are rejected here.
      if not isinstance(node.func, ast.Name):
        raise SyntaxError('Malformed string: %r' % string)
      callable_name = node.func.id
      if callable_name not in available_callables:
        raise SyntaxError('Unknown callable: %r' % callable_name)
      # Keyword arguments would otherwise be silently ignored.
      if node.keywords:
        raise SyntaxError('Malformed string: %r' % string)
      return available_callables[callable_name](
          *[_eval(arg) for arg in node.args])
    raise SyntaxError('Malformed string: %r' % string)

  return _eval(node)


def evaluate_expression_strings_1d_grid(expression_strings,
                                        num_samples,
                                        num_grids,
                                        callables=None,
                                        arguments=None):
  """Evaluates a list of expression strings on a [num_samples, num_grids] grid.

  Args:
    expression_strings: List of num_expressions strings, the expressions
        to evaluate.
    num_samples: Integer, number of samples to evaluate.
    num_grids: Integer, number of 1d grid points.
    callables: A dictionary mapping a function name in expression string to a
        callable function. Passed through to numpy_array_eval().
    arguments: A dictionary mapping an argument name in expression string to a
        number or numpy array. Numpy array must have shape
        [num_samples, num_grids]. Numbers may be Python int/float or numpy
        integer/floating scalars. If None, no arguments are available.

  Raises:
    ValueError: Occurs if shape of a numpy array argument is not
        [num_samples, num_grids].
    ValueError: Occurs if the argument is not np.ndarray, an int, or a float
        (Python or numpy scalar).

  Returns:
    A float32 numpy array with shape [num_expressions, num_samples, num_grids].
  """
  # Without this, the default arguments=None crashed while validating below.
  if arguments is None:
    arguments = {}

  expected_shape = (num_samples, num_grids)
  for symbol, value in arguments.items():
    if isinstance(value, np.ndarray):
      if value.shape != expected_shape:
        raise ValueError('The shape of %s is expected to be (%d, %d) '
                         'but got %s.'
                         % (symbol, num_samples, num_grids, str(value.shape)))
    # numpy scalars such as np.float32 / np.int64 are not subclasses of the
    # Python float / int types, so they are accepted explicitly.
    elif not isinstance(value, (int, float, np.integer, np.floating)):
      raise ValueError('Argument should be np.ndarray, int, or float. '
                       'but got %s, %s.' % (symbol, type(value)))

  results = np.zeros(
      (len(expression_strings), num_samples, num_grids), dtype=np.float32)
  for i, expression_string in enumerate(expression_strings):
    # Scalar-valued expressions broadcast across the whole grid slice.
    results[i] = numpy_array_eval(expression_string,
                                  callables=callables,
                                  arguments=arguments)
  return results
